Hi @zhanghang1989, is there any difference between the two in-place updates timed below?

    import mxnet as mx
    import time

    T = 1000
    N = 1000
    while 1:
        # Time T in-place additions written as `a += 1`
        ti = time.time()
        a = mx.nd.arange(N)
        for i in range(T):
            a += 1
        mx.nd.waitall()
        print('a += b: ', time.time() - ti)
        # Time T additions written as `a[:] += 1`
        ti = time.time()
        a = mx.nd.arange(N)
        for i in range(T):
            a[:] += 1
        mx.nd.waitall()
        print('a[:] += b: ', time.time() - ti)

Output:
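At the Python level, the two forms in the script above go through different special methods, whatever the array library does underneath. The sketch below uses a toy `Tracker` class (a hypothetical stand-in, not `mx.nd.NDArray`) that only records which special methods are called. It says nothing about MXNet's actual timings.

```python
class Tracker:
    """Toy stand-in for an array type; hypothetical, not mx.nd.NDArray."""

    def __init__(self):
        self.calls = []

    def __iadd__(self, other):
        self.calls.append("__iadd__")
        return self

    def __getitem__(self, key):
        self.calls.append("__getitem__")
        return self  # return self so every call lands in one log

    def __setitem__(self, key, value):
        self.calls.append("__setitem__")


def trace_plain():
    """Record the special methods used by `a += 1`."""
    a = Tracker()
    a += 1
    return a.calls


def trace_sliced():
    """Record the special methods used by `a[:] += 1`."""
    a = Tracker()
    a[:] += 1
    return a.calls


if __name__ == "__main__":
    print(trace_plain())
    print(trace_sliced())
```

The sliced form adds a `__getitem__` before the in-place add and a `__setitem__` after it, so it does strictly more work per iteration at the Python level.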
@zhanghang1989 The update rule in this PR is the following -

    mom_data[i] = param_momentum*mom_data[i];
    KERNEL_ASSIGN(out_data[i], req, weight_data[i]-mom_data[i]
                  +(param_momentum+1)*(mom_data[i]
                  -(param_lr*(param_rescale_grad*grad_data[i]+param_wd*weight_data[i]))));

This update rule is the same as the following pseudocode -
which, when simplified, translates to
(It is the same rule used in Keras as well - https://stats.stackexchange.com/questions/179915/whats-the-difference-between-momentum-based-gradient-descent-and-nesterovs-acc)
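The claimed equivalence can be checked numerically. The sketch below transcribes the quoted kernel lines for a single scalar element and compares them with a Nesterov step assumed here to have the Keras-style form `v = mu*v - lr*g; w = w + mu*v - lr*g`. The function names are illustrative and not part of MXNet or Keras.

```python
def pr_kernel_step(weight, mom, grad, lr, momentum, wd=0.0, rescale_grad=1.0):
    """Scalar transcription of the two kernel lines quoted above."""
    mom = momentum * mom
    out = weight - mom + (momentum + 1) * (
        mom - lr * (rescale_grad * grad + wd * weight))
    return out, mom


def keras_style_nag_step(weight, mom, grad, lr, momentum, wd=0.0,
                         rescale_grad=1.0):
    """Nesterov step in the assumed Keras-style form."""
    g = rescale_grad * grad + wd * weight
    mom = momentum * mom - lr * g             # v = mu*v - lr*g
    out = weight + momentum * mom - lr * g    # w = w + mu*v - lr*g
    return out, mom


if __name__ == "__main__":
    args = dict(weight=1.0, mom=0.5, grad=0.2, lr=0.1, momentum=0.9, wd=0.01)
    w_pr, m_pr = pr_kernel_step(**args)
    w_ref, m_ref = keras_style_nag_step(**args)
    print("weights agree:", abs(w_pr - w_ref) < 1e-12)
    print("momentum states:", m_pr, m_ref)
```

Under these assumptions the two weight updates agree, while the stored momentum does not: the transcribed kernel keeps only `momentum * mom` and omits the `- lr * g` term. That matches the review comment in this thread that the weight update is correct but the momentum update needs fixing.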
The weight update is correct. Please fix the momentum update at the end.
Yes, will change the momentum state update.
I am not familiar with the symbol API. Just write some pseudocode to show how NAG works :)
a3500d7 to b644e5f
Thanks @zhanghang1989 and @anirudhacharya
@anirudhacharya Perl GPU tests are failing: http://jenkins.mxnet-ci.amazon-ml.com/blue/rest/organizations/jenkins/pipelines/mxnet-validation/pipelines/unix-gpu/branches/master/runs/1029/nodes/304/steps/568/log/?start=0
* fix update rules
* readable updates in unit test
* mom update
Description
Fixes #15543
Changes
For review - @zhanghang1989 @apeforest @eric-haibin-lin